
import torch

# Maximum Softmax Probability (MSP) out-of-distribution detector
def msp(dataset_in, dataset_out, net, device):
    """Score samples with the Maximum Softmax Probability (MSP) OOD detector.

    Runs ``net`` over the test loader of the out-of-distribution dataset and
    then over the test loader of the in-distribution dataset. The OOD score
    of each sample is its negated maximum softmax probability, so a higher
    score means "more likely OOD".

    This function only disables gradient tracking; it does not change the
    train/eval mode of ``net`` (the caller is responsible for that).

    Args:
        dataset_in: object whose ``test_loader`` yields ``(data, labels)``
            batches of in-distribution samples.
        dataset_out: same, for out-of-distribution samples.
        net: model applied to each batch of inputs; softmax is taken over
            dim 1 of its output.
        device: device the inputs are moved to before the forward pass.

    Returns:
        ``(labels, pred)`` as 1-D NumPy arrays of equal length. ``labels``
        is 1.0 for OOD samples (listed first) and 0.0 for in-distribution
        samples. ``pred`` is the negated max softmax probability of each
        sample, in the same order.
    """
    with torch.no_grad():
        # Build the outputs from what the loaders actually yield instead of
        # pre-sizing from len(loader.dataset). Samplers or drop_last can yield
        # fewer samples, which would leave zero-filled, mislabeled slots.
        scores_out = _max_softmax_scores(dataset_out.test_loader, net, device)
        scores_in = _max_softmax_scores(dataset_in.test_loader, net, device)

    pred = torch.cat((scores_out, scores_in))
    labels = torch.cat((torch.ones_like(scores_out), torch.zeros_like(scores_in)))
    return labels.cpu().numpy(), -pred.cpu().numpy()


def _max_softmax_scores(loader, net, device):
    """Return the max softmax probability of ``net`` for every sample in ``loader``.

    Only the inputs are moved to ``device``; the labels yielded by the loader
    are ignored. The result is a 1-D tensor on ``device``, in loader order and
    cast to the default floating dtype. It is empty if the loader yields
    nothing.
    """
    scores = [
        torch.nn.functional.softmax(net(data.to(device)), dim=1).max(dim=1).values
        for data, _ in loader
    ]
    if not scores:
        return torch.zeros(0, device=device)
    return torch.cat(scores).to(torch.get_default_dtype())